"""
项目运行入口
"""

from text_classification.models import WTFIDF
from dataset import AclImdbDataset
import jieba
import numpy as np


def tfidf_train(_train_data, _train_label, _test_data, _test_label):
    """Fit a WTFIDF model, predict on the test set, and report the accuracy.

    Args:
        _train_data: Training samples, passed unchanged to ``WTFIDF.fit``.
        _train_label: Labels for ``_train_data``, passed to ``WTFIDF.fit``.
        _test_data: Samples passed unchanged to ``WTFIDF.predict``.
        _test_label: Expected labels for ``_test_data`` (any array-like).

    Returns:
        The fraction of entries in ``_test_label`` that equal the model's
        predictions, as a Python float. The same value is also printed.
    """
    # jieba.lcut is handed to the model as its `participle_utils` callable.
    tfidf_model = WTFIDF(participle_utils=jieba.lcut)

    # Train, then predict on the held-out samples.
    tfidf_model.fit(_train_data, _train_label)
    predict_val = tfidf_model.predict(_test_data)

    # Coerce both sides to arrays so `==` is always element-wise. Comparing
    # two plain lists would produce a single bool and a meaningless accuracy.
    expected = np.asarray(_test_label)
    predicted = np.asarray(predict_val)
    accuracy = np.sum(expected == predicted) / len(expected)

    print(f'TFIDF模型准确率: {accuracy}')
    return float(accuracy)


def main():
    """Train and evaluate the TF-IDF model on a 1% random subsample.

    Loads ``AclImdbDataset``, draws 1% of the train rows and 1% of the test
    rows using a fixed seed, and passes the selected ``content`` / ``label``
    values to ``tfidf_train``.
    """
    # Load the dataset.
    acl_imdb_dataset = AclImdbDataset()

    # Work on 1% of each split.
    train_total = int(acl_imdb_dataset.train_data_size)
    test_total = int(acl_imdb_dataset.test_data_size)
    train_size = train_total // 100
    test_size = test_total // 100
    seed_num = 66

    # Pick random row positions. `replace=False` guarantees distinct rows;
    # `randint` samples with replacement and could select the same row more
    # than once. The global seed is kept so any later use of `np.random`
    # stays reproducible as well.
    np.random.seed(seed_num)
    train_idx = np.random.choice(train_total, train_size, replace=False)
    test_idx = np.random.choice(test_total, test_size, replace=False)

    # Select the sampled rows by position and convert them to NumPy arrays.
    train_data = acl_imdb_dataset.train_data['content'].iloc[train_idx].to_numpy()
    train_label = acl_imdb_dataset.train_data['label'].iloc[train_idx].to_numpy()
    test_data = acl_imdb_dataset.test_data['content'].iloc[test_idx].to_numpy()
    test_label = acl_imdb_dataset.test_data['label'].iloc[test_idx].to_numpy()

    tfidf_train(train_data, train_label, test_data, test_label)


if __name__ == '__main__':
    main()

